/**
 * Copyright 2017-2023 by XGBoost contributors
 */
#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_

#include <dmlc/omp.h>

#include <cmath>

#include "../common/math.h"
#include "xgboost/data.h"  // MetaInfo
#include "xgboost/logging.h"
#include "xgboost/task.h"  // ObjInfo

namespace xgboost::obj {
// linear regression
struct LinearSquareLoss {
  XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return x; }
  XGBOOST_DEVICE static bool CheckLabel(bst_float) { return true; }
  XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
    return predt - label;
  }
  XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float, bst_float) { return 1.0f; }
  static bst_float ProbToMargin(bst_float base_score) { return base_score; }
  static const char* LabelErrorMsg() { return ""; }
  static const char* DefaultEvalMetric() { return "rmse"; }

  static const char* Name() { return "reg:squarederror"; }
  static ObjInfo Info() { return {ObjInfo::kRegression, true, false}; }
};

struct SquaredLogError {
  XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return x; }
  XGBOOST_DEVICE static bool CheckLabel(bst_float label) { return label > -1; }
  XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
    predt = fmaxf(predt, -1 + 1e-6);  // ensure correct value for log1p
    return (std::log1p(predt) - std::log1p(label)) / (predt + 1);
  }
  XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
    predt = fmaxf(predt, -1 + 1e-6);
    float res = (-std::log1p(predt) + std::log1p(label) + 1) / std::pow(predt + 1, 2);
    res = fmaxf(res, 1e-6f);
    return res;
  }
  static bst_float ProbToMargin(bst_float base_score) { return base_score; }
  static const char* LabelErrorMsg() {
    return "label must be greater than -1 for rmsle so that log(label + 1) can be valid.";
  }
  static const char* DefaultEvalMetric() { return "rmsle"; }

  static const char* Name() { return "reg:squaredlogerror"; }

  static ObjInfo Info() { return ObjInfo::kRegression; }
};

// logistic loss for probability regression task
struct LogisticRegression {
  XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return common::Sigmoid(x); }
  XGBOOST_DEVICE static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
  XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
    return predt - label;
  }
  XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float) {
    const float eps = 1e-16f;
    return fmaxf(predt * (1.0f - predt), eps);
  }
  static bst_float ProbToMargin(bst_float base_score) {
    CHECK(base_score > 0.0f && base_score < 1.0f)
        << "base_score must be in (0,1) for logistic loss, got: " << base_score;
    return -logf(1.0f / base_score - 1.0f);
  }
  static const char* LabelErrorMsg() { return "label must be in [0,1] for logistic regression"; }
  static const char* DefaultEvalMetric() { return "rmse"; }

  static const char* Name() { return "reg:logistic"; }

  static ObjInfo Info() { return ObjInfo::kRegression; }
};

// logistic loss for binary classification task
struct LogisticClassification : public LogisticRegression {
  static const char* DefaultEvalMetric() { return "logloss"; }
  static const char* Name() { return "binary:logistic"; }
  static ObjInfo Info() { return ObjInfo::kBinary; }
};

// logistic loss, but predict un-transformed margin
struct LogisticRaw : public LogisticRegression {
  XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return x; }
  XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
    predt = common::Sigmoid(predt);
    return predt - label;
  }
  XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float) {
    const float eps = 1e-16f;
    predt = common::Sigmoid(predt);
    return fmaxf(predt * (1.0f - predt), eps);
  }
  static bst_float ProbToMargin(bst_float base_score) { return base_score; }
  static const char* DefaultEvalMetric() { return "logloss"; }

  static const char* Name() { return "binary:logitraw"; }

  static ObjInfo Info() { return ObjInfo::kRegression; }
};

// gamma deviance loss.
class GammaDeviance {
 public:
  XGBOOST_DEVICE static float PredTransform(float x) { return std::exp(x); }
  XGBOOST_DEVICE static float ProbToMargin(float x) { return std::log(x); }
  XGBOOST_DEVICE static float FirstOrderGradient(float p, float y) {
    return 1.0f - y / p;
  }
  XGBOOST_DEVICE static float SecondOrderGradient(float p, float y) { return y / p; }
  static ObjInfo Info() { return ObjInfo::kRegression; }
  static const char* Name() { return "reg:gamma"; }
  static const char* DefaultEvalMetric() { return "gamma-deviance"; }
  XGBOOST_DEVICE static bool CheckLabel(float x) { return x > 0.0f; }
  static const char* LabelErrorMsg() { return "label must be positive for gamma regression."; }
};
}  // namespace xgboost::obj
#endif  // XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
